{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "J8BjXeY2UcnN"
   },
   "source": [
    "# Using Generation Models for Summarization\n",
    "This notebook demonstrates a simple way of using Cohere's generation models to summarize text.\n",
    "<img src=\"https://github.com/cohere-ai/notebooks/raw/main/notebooks/images/summarization.png\"\n",
    "    style=\"width:100%; max-width:600px\" alt=\"provided with the right prompt, a language model can generate multiple candidate summaries\" />\n",
    "\n",
    "We will use a simple prompt that includes two examples and a task description:\n",
    "\n",
    "`\"<input phrase>\". In summary: \"<summary>\"`."
   ]
  },
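  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch, a prompt in this format could also be assembled from a list of examples instead of being typed by hand. The helper name `build_prompt`, its arguments, and the `---` separator default are illustrative assumptions and not part of the Cohere SDK.\n",
    "\n",
    "```python\n",
    "def build_prompt(examples, text, separator='---'):\n",
    "    # examples: list of (passage, summary) pairs shown to the model\n",
    "    # text: the passage we want summarized\n",
    "    blocks = [\n",
    "        f'\"{passage}\"\\nIn summary: \"{summary}\"'\n",
    "        for passage, summary in examples\n",
    "    ]\n",
    "    # Leave the last summary open so the model writes it;\n",
    "    # generation can then stop at the closing double quote.\n",
    "    blocks.append(f'\"{text}\"\\nIn summary: \"')\n",
    "    return f'\\n{separator}\\n'.join(blocks)\n",
    "\n",
    "\n",
    "examples = [\n",
    "    ('It is recognizable by its black-and-white patterned body',\n",
    "     'Its body has a black and white pattern'),\n",
    "]\n",
    "prompt_sketch = build_prompt(examples, 'Killer whales have a diverse diet')\n",
    "print(prompt_sketch)\n",
    "```\n",
    "\n",
    "Building the prompt this way keeps the spacing and quoting identical across examples, which makes it easier to swap examples in and out."
   ]
  },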
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EqdDiovEBRLw",
    "outputId": "b72dc9ea-d73f-4291-bc59-7d3b1d8490cd"
   },
   "outputs": [],
   "source": [
    "# Let's first install Cohere's python SDK\n",
    "# TODO: upgrade to \"cohere>5\"",
"!pip install \"cohere<5\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Hf-3hMa7BOkU"
   },
   "outputs": [],
   "source": [
    "import cohere\n",
    "import time\n",
    "import pandas as pd\n",
    "# Paste your API key here. Remember to not share it publicly \n",
    "api_key = ''\n",
    "co = cohere.Client(api_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bOzfT9gJFfDl"
   },
   "source": [
    "\n",
    "Our prompt is geared for paraphrasing to simplify an input sentence. It contains two examples. The sentence we want it to summarize is:\n",
    "\n",
    "**Killer whales have a diverse diet, although individual populations often specialize in particular types of prey.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "douHOfl0JJq-",
    "outputId": "5df6eaba-3210-4e55-d6a9-7b40ae1d30b8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\"The killer whale or orca (Orcinus orca) is a toothed whale belonging to the oceanic dolphin family, of which it is the largest member\"\n",
      "In summary: \"The killer whale or orca is the largest type of dolphin\"\n",
      "---\n",
      "\"It is recognizable by its black-and-white patterned body\" \n",
      "In summary:\"Its body has a black and white pattern\"\n",
      "---\n",
      "\"Killer whales have a diverse diet, although individual populations often specialize in particular types of prey\" \n",
      "In summary:\"\n"
     ]
    }
   ],
   "source": [
    "prompt = '''\"The killer whale or orca (Orcinus orca) is a toothed whale belonging to the oceanic dolphin family, of which it is the largest member\"\n",
    "In summary: \"The killer whale or orca is the largest type of dolphin\"\n",
    "---\n",
    "\"It is recognizable by its black-and-white patterned body\" \n",
    "In summary:\"Its body has a black and white pattern\"\n",
    "---\n",
    "\"Killer whales have a diverse diet, although individual populations often specialize in particular types of prey\" \n",
    "In summary:\"'''\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "m3PO3HH6FqzQ"
   },
   "source": [
    "We get several completions from the model via the API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "s6wsSRKaBcai"
   },
   "outputs": [],
   "source": [
    "n_generations = 5\n",
    "\n",
    "prediction = co.generate(\n",
    "    model='large',\n",
    "    prompt=prompt,\n",
    "    return_likelihoods = 'GENERATION',\n",
    "    stop_sequences=['\"'],\n",
    "    max_tokens=50,\n",
    "    temperature=0.7,\n",
    "    num_generations=n_generations,\n",
    "    k=0,\n",
    "    p=0.75)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get list of generations\n",
    "gens = []\n",
    "likelihoods = []\n",
    "for gen in prediction.generations:\n",
    "    gens.append(gen.text)\n",
    "    \n",
    "    sum_likelihood = 0\n",
    "    for t in gen.token_likelihoods:\n",
    "        sum_likelihood += t.likelihood\n",
    "    # Get sum of likelihoods\n",
    "    likelihoods.append(sum_likelihood)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 274
    },
    "id": "2sLwp-3ABiVj",
    "outputId": "ddda4d8b-9f95-454f-f812-8b5d940936d2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Candidate summaries for the sentence: \n",
      "\"Killer whales have a diverse diet, although individual populations often specialize in particular types of prey.\"\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>generation</th>\n",
       "      <th>likelihood</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Killer whales have a diverse diet\"</td>\n",
       "      <td>-3.208850</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Its diet is diverse\"</td>\n",
       "      <td>-3.487236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Their diet is diverse\"</td>\n",
       "      <td>-3.761171</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Different populations have different diets\"</td>\n",
       "      <td>-6.415764</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Their diet consists of a variety of marine life\"</td>\n",
       "      <td>-11.764865</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         generation  likelihood\n",
       "0                Killer whales have a diverse diet\"   -3.208850\n",
       "1                              Its diet is diverse\"   -3.487236\n",
       "2                            Their diet is diverse\"   -3.761171\n",
       "3       Different populations have different diets\"   -6.415764\n",
       "4  Their diet consists of a variety of marine life\"  -11.764865"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.options.display.max_colwidth = 200\n",
    "# Create a dataframe for the generated sentences and their likelihood scores\n",
    "df = pd.DataFrame({'generation':gens, 'likelihood': likelihoods})\n",
    "# Drop duplicates\n",
    "df = df.drop_duplicates(subset=['generation'])\n",
    "# Sort by highest sum likelihood\n",
    "df = df.sort_values('likelihood', ascending=False, ignore_index=True)\n",
    "print('Candidate summaries for the sentence: \\n\"Killer whales have a diverse diet, although individual populations often specialize in particular types of prey.\"')\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M3HM7go_tebY"
   },
   "source": [
    "In a lot of cases, better generations can be reached by creating multiple generations then ranking and filtering them. In this case we're ranking the generations by their average likelihoods. \n",
    "\n",
    "## Hyperparameters\n",
    "It's worth spending some time learning the various hyperparameters of the generation endpoint. For example, [temperature](https://docs.cohere.ai/temperature-wiki) tunes the degree of randomness in the generations. Other parameters include [top-k and top-p](https://docs.cohere.ai/token-picking) as well as `frequency_penalty` and `presence_penalty` which can reduce the amount of repetition in the output of the model. See the [API reference of the generate endpoint](https://docs.cohere.ai/generate-reference) for more details on all the parameters."
   ]
  },
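  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for temperature and top-p, here is a toy sketch on a made-up three-token distribution. It follows the usual textbook description of these techniques (softmax with temperature, nucleus filtering) and is not a description of how the Cohere API implements them; the function names and the logits are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def apply_temperature(logits, temperature):\n",
    "    # Softmax over logits divided by the temperature\n",
    "    scaled = [x / temperature for x in logits]\n",
    "    peak = max(scaled)  # subtract the max for numerical stability\n",
    "    exps = [math.exp(x - peak) for x in scaled]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def top_p_filter(probs, p):\n",
    "    # Keep the most likely tokens until their cumulative\n",
    "    # probability reaches p, then renormalize what is kept\n",
    "    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)\n",
    "    kept, cumulative = [], 0.0\n",
    "    for i in order:\n",
    "        kept.append(i)\n",
    "        cumulative += probs[i]\n",
    "        if cumulative >= p:\n",
    "            break\n",
    "    mass = sum(probs[i] for i in kept)\n",
    "    return {i: probs[i] / mass for i in kept}\n",
    "\n",
    "\n",
    "logits = [2.0, 1.0, 0.1]  # made-up scores for three hypothetical tokens\n",
    "for t in (0.3, 0.7, 1.5):\n",
    "    print(t, [round(q, 3) for q in apply_temperature(logits, t)])\n",
    "print(top_p_filter(apply_temperature(logits, 0.7), p=0.75))\n",
    "```\n",
    "\n",
    "Lower temperatures concentrate probability on the top token, and higher ones spread it out. Top-p then discards the unlikely tail before sampling."
   ]
  },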
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Basic Summarization Notebook.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
